
from argparse import ArgumentParser

from metric.myMetrics import Metric

import glob
import os
import re
import logging
import random
import pickle
import json
import shutil
import torch
from torch.utils.data import Dataset, DataLoader , SequentialSampler
from torch.nn.utils.rnn import pad_sequence

import numpy as np

from tqdm import tqdm, trange

from model.vallia_vae import BlenderbotSmallForConditionalGeneration

from model.vallina_configuration import BlenderbotSmallConfig
from src.transformers import BlenderbotSmallTokenizer

from src.transformers import (
    AdamW,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)



# Configs
# Module-level logger, obtained with the standard ``logging.getLogger(__name__)`` pattern.
logger = logging.getLogger(__name__)

class InputFeatures_train(object):
    """Plain record holding one flattened example (encoder or decoder side).

    The per-token fields (``input_ids``, ``position_ids``, ``token_type_ids``,
    ``role_ids``, ``lm_labels``, ``strategy_ids``, ``emotion_ids``) are stored
    exactly as given. ``input_len`` falls back to ``len(input_ids)`` when it is
    not supplied.
    """

    def __init__(self, conv_id, input_ids, position_ids, token_type_ids,
                 role_ids, lm_labels, cls_position, cls_label, strategy_ids, emotion_ids, input_len=None):
        # identity and token-level sequences
        self.conv_id = conv_id
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.token_type_ids = token_type_ids
        self.role_ids = role_ids
        self.lm_labels = lm_labels
        # classification read-out position and its target label
        self.cls_position = cls_position
        self.cls_label = cls_label
        # per-token auxiliary ids
        self.strategy_ids = strategy_ids
        self.emotion_ids = emotion_ids
        self.input_len = len(input_ids) if input_len is None else input_len

class InputFeatures_blender(object):
    """Merge an encoder-side and a decoder-side feature into one example.

    Encoder attributes are copied under their own names. Decoder attributes are
    copied with a ``decoder_`` prefix, except ``conv_id``, which is taken from
    the encoder only. ``input_len`` is not copied from either side.
    """

    def __init__(self, encoder_feature, decoder_feature):
        sequence_fields = ("input_ids", "position_ids", "token_type_ids", "role_ids",
                           "lm_labels", "cls_position", "cls_label", "strategy_ids",
                           "emotion_ids")

        self.conv_id = encoder_feature.conv_id
        for field in sequence_fields:
            setattr(self, field, getattr(encoder_feature, field))

        for field in sequence_fields:
            setattr(self, "decoder_" + field, getattr(decoder_feature, field))

def _norm_text(text, add_special):
    emo, r, t, *toks = text.strip().split()
    try:
        emo = 0 # int(emo)
        r = int(r)
        t = int(t)
        toks = ' '.join(toks[:len(toks)])
        
        if add_special:
            toks = '[CLS] ' + toks
            
    except Exception as e:
        raise e
    return emo, r, t, toks

def _remove_special_tokens(text):
    # 定义要删除的标识符列表
    special_tokens = ["[Question]", "[Reflection of feelings]", "[Information]", "[Restatement or Paraphrasing]", "[Others]", "[Self-disclosure]", "[Affirmation and Reassurance]", "[Providing Suggestions]"]  # 可以添加其他标识符

    # 构建正则表达式模式，将所有标识符替换为空字符串
    pattern = "|".join(re.escape(token) for token in special_tokens)
    text_without_tokens = re.sub(pattern, "", text)

    return text_without_tokens

def _extract_label(text):
    # 构建正则表达式模式，匹配所有方括号内的内容
    pattern = r"(\[[^\]]+\])"
    matches = re.findall(pattern, text)
    if len(matches) >= 2:  # 检查是否至少有两个匹配
        return matches[1]  # 返回第二个匹配的内容，包括方括号
    else:
        return None

def _get_input_from_text(text, tokenizer, utter_emotion, add_special, strategy = True,  cls = False):
    """Split an ``EOS``-delimited dialogue string into per-utterance model inputs.

    Each non-empty segment is parsed by ``_norm_text`` as
    ``"<emotion> <role> <turn> <tokens...>"``. Seeker turns (role 0) are encoded
    as-is. Other turns have their strategy tags removed before encoding.

    Args:
        text: dialogue string with utterances separated by the literal ``"EOS"``.
        tokenizer: object with ``encode(str) -> list[int]``. It also needs a
            ``cls`` attribute when ``cls=True`` on the decoder path.
        utter_emotion: optional sequence of emotion names indexed like the raw
            ``EOS`` segments, mapped through ``emotion_map``. None skips emotions.
        add_special: True for encoder input (text gets a ``[CLS]`` prefix),
            False for decoder input.
        strategy: decoder path only; when False, ids >= 54944 are dropped.
        cls: decoder path only (with ``strategy``); prepend ``tokenizer.cls``
            and drop ids >= 54944.

    Returns:
        ``(inputs, roles, turns, strategy_labels, utter_emo, emotion)``:
        - ``inputs``: token-id lists.
        - ``roles``, ``turns``: int lists.
        - ``strategy_labels``: ``strategy_map`` ids, with 8 meaning "none".
        - ``utter_emo``: emotion ids.
        - ``emotion``: the emo of the first parsed segment, or None if there is none.

    Raises:
        KeyError: for an unknown strategy tag or emotion name.
    """
    srcs = text.strip().split("EOS")

    inputs = []
    roles = []
    turns = []
    strategy_labels = []
    utter_emo = []

    emotion = None

    strategy_map = {"[Question]": 0, "[Reflection of feelings]": 1, "[Information]": 2, "[Restatement or Paraphrasing]": 3, "[Others]": 4, "[Self-disclosure]": 5, "[Affirmation and Reassurance]": 6, "[Providing Suggestions]": 7}

    emotion_map = {'angry': 0, 'disgust': 1, 'fear': 2, 'joy': 3, 'sadness': 4, 'neutral': 5}

    # Per the original note: id 54944 is cls and 54945 is sep; ids from here up are filtered.
    special_id_start = 54944

    for idx, src in enumerate(srcs):

        if src == "":
            continue

        src_emo, src_role, src_turn, src = _norm_text(src, add_special)

        if src_role == 0:  # seeker utterance: encode verbatim
            context_id = tokenizer.encode(src)
        else:  # supporter utterance: strip strategy tag(s) before encoding
            context_id = tokenizer.encode(_remove_special_tokens(src))

        if emotion is None:
            emotion = src_emo

        # Emotions are collected for both seeker and supporter turns. They are
        # label ids, not text, so they never pass through the tokenizer.
        # NOTE(review): ``idx`` also counts skipped empty segments, so
        # ``utter_emotion`` must align with the raw ``EOS`` split.
        if utter_emotion is not None:
            utter_emo.append(emotion_map[utter_emotion[idx]])

        if not add_special:  # decoder input
            if not strategy:
                context_id = [i for i in context_id if i < special_id_start]
            elif cls:
                # NOTE(review): assumes ``tokenizer.cls`` is a list of ids — confirm.
                context_id = tokenizer.cls + [i for i in context_id if i < special_id_start]

        if src_role == 1:
            if add_special:
                # Encoder text starts with "[CLS] ", so the tag is the 2nd bracketed span.
                label = _extract_label(src)
            else:
                # Decoder text has no [CLS]; take the first bracketed span.
                try:
                    label = "[" + src.split("[")[1].split("]")[0] + "]"
                except IndexError:
                    label = None
            # A missing tag maps to 8 ("no strategy") on both paths; the encoder
            # path used to crash with KeyError on ``strategy_map[None]``.
            strategy_labels.append(8 if label is None else strategy_map[label])
        else:
            strategy_labels.append(8)

        inputs.append(context_id)
        roles.append(src_role)
        turns.append(src_turn)

    return inputs, roles, turns, strategy_labels, utter_emo, emotion
    
# def _Emotion_tag_to_tag_idx():
#     return {'angry': 0, 'disgust': 1, 'fear': 2, 'joy': 3, 'sadness': 4, 'neutral': 5}

def _make_feature_copy(args, id_, sents, utt_emo_label, rls, ts, cls, bos, eos, pad=False, block_size=512, strategy_labels=None, evaluate=False, str_embd=False, decoder=False):
    """Variant of ``_make_feature`` with a different signature order and padding.

    Differences from ``_make_feature``:
    - ``block_size`` comes before ``strategy_labels`` in the signature.
    - Padded role ids are 0 instead of -1.
    - During evaluation the masked strategy token id is ``label + 50257 + 4687 + 1``.

    Every utterance in ``sents`` is followed by ``eos``. Only the last
    utterance (plus its ``eos``) is a language-modelling target; all earlier
    positions are labelled -100.

    Args:
        args, cls, bos, decoder: unused; kept for call-site compatibility.
        id_: conversation index stored as ``conv_id``.
        sents: list of token-id lists, one per utterance.
        utt_emo_label, rls, ts: per-utterance emotion / role / turn ids, parallel to ``sents``.
        eos: token id appended after every utterance.
        pad: right-pad all sequences to a multiple of 8.
        block_size: maximum length; longer inputs lose whole leading utterances.
        strategy_labels: per-utterance strategy ids; the last one becomes
            ``cls_label``. Required whenever ``sents`` is non-empty.
        evaluate: mask the target's strategy token in ``lm_labels``.
        str_embd: fill ``strategy_ids`` with the last strategy label instead of 8.

    Returns:
        InputFeatures_train. Every field is an empty list when ``sents`` is empty.

    Raises:
        ValueError: if the feature ends up empty, or if an over-long input has
            no ``eos`` within its last ``block_size`` tokens.
    """
    if len(sents) == 0:
        return InputFeatures_train([], [], [], [], [],
                            [], [] , [], [], [])

    input_ids = [tok for s in sents for tok in s + [eos]]
    lm_labels, token_type_ids, roles, strategy_ids, emotion_ids = [], [], [], [], []
    last_utt = len(sents) - 1
    for k, s in enumerate(sents):
        span = len(s) + 1  # utterance tokens plus the trailing eos
        token_type_ids += [ts[k]] * span
        # str_embd uses strategy as an embedding; by default strategy is treated as a token.
        strategy_ids += [strategy_labels[-1] if str_embd else 8] * span
        roles += [rls[k]] * span
        emotion_ids += [utt_emo_label[k]] * span
        # Only the final utterance is the autoregressive target.
        lm_labels += (s + [eos]) if k == last_utt else [-100] * span

    # Drop trailing positions without a target. The final label is always eos,
    # so this is a no-op unless eos == -100.
    end = len(lm_labels)
    while end > 0 and lm_labels[end - 1] == -100:
        end -= 1
    input_ids = input_ids[:end]
    lm_labels = lm_labels[:end]
    token_type_ids = token_type_ids[:end]
    roles = roles[:end]
    emotion_ids = emotion_ids[:end]
    strategy_ids = strategy_ids[:end] if str_embd else [8] * len(input_ids)

    assert (len(input_ids) == len(token_type_ids)
            == len(lm_labels) == len(roles) == len(strategy_ids) == len(emotion_ids))

    # Cut right after the first eos inside the last ``block_size`` tokens so only
    # whole leading utterances are dropped. The window used to be hard-coded to 512.
    if len(input_ids) > block_size:
        cut_index = input_ids.index(eos, -block_size) + 1
        input_ids = input_ids[cut_index:]
        token_type_ids = token_type_ids[cut_index:]
        lm_labels = lm_labels[cut_index:]
        roles = roles[cut_index:]
        strategy_ids = strategy_ids[cut_index:]
        emotion_ids = emotion_ids[cut_index:]

    # pad to multiples of 8
    if pad:
        while len(input_ids) % 8 != 0:
            input_ids.append(0)
            token_type_ids.append(0)
            lm_labels.append(-100)
            roles.append(0)
            emotion_ids.append(6)
            strategy_ids.append(8)
        assert len(input_ids) % 8 == 0

    position_ids = list(range(len(input_ids)))
    assert (len(input_ids) == len(position_ids) == len(token_type_ids) == len(lm_labels) == len(roles) == len(strategy_ids) == len(emotion_ids))

    if not input_ids:
        # Previously dropped into pdb here; fail loudly instead.
        raise ValueError("empty feature for conversation {}".format(id_))
    if len(input_ids) == 1:
        logger.debug("single-token feature: %s %s %s %s %s",
                     input_ids, lm_labels, token_type_ids, roles, emotion_ids)

    # Generation setting: the read-out position is the last eos of the context.
    # (The alternative, the second-to-last eos, was dead code and has been removed.)
    cls_position = len(input_ids) - 1 - input_ids[::-1].index(eos)

    if evaluate and strategy_labels[-1] != 8:
        # Do not score the target's strategy token.
        # NOTE(review): the id offset 50257 + 4687 + 1 is taken as-is — confirm it
        # matches the tokenizer in use.
        strategy_token = strategy_labels[-1] + 50257 + 4687 + 1
        if strategy_token in lm_labels:
            lm_labels[lm_labels.index(strategy_token)] = -100

    feature = InputFeatures_train(id_, input_ids, position_ids, token_type_ids, roles,
                            lm_labels, cls_position , strategy_labels[-1], strategy_ids, emotion_ids)
    return feature

# NOTE(author): the logic of the function below is not yet fully understood.
def _make_feature(args, id_, sents, utt_emo_label, rls, ts, cls, bos, eos, pad=False, strategy_labels=None, evaluate=False,  decoder=False, block_size=512, str_embd=False):
    """Flatten tokenised utterances into a single ``InputFeatures_train`` example.

    Every utterance is followed by ``eos``, and all utterances are concatenated.
    Only the last utterance (plus its ``eos``) is a language-modelling target;
    every earlier position is labelled -100.

    Args:
        args, cls, bos, decoder: unused; kept so existing call sites keep working.
        id_: conversation index, stored as ``conv_id``.
        sents: list of token-id lists, one per utterance.
        utt_emo_label: per-utterance emotion ids, parallel to ``sents``.
        rls: per-utterance role ids, parallel to ``sents``.
        ts: per-utterance turn ids, stored as ``token_type_ids``.
        eos: token id appended after every utterance.
        pad: right-pad to a multiple of 8. Padding values are -1 for roles, 6
            for emotions, 8 for strategies, -100 for labels and 0 for the rest.
        strategy_labels: per-utterance strategy ids; the last entry becomes
            ``cls_label``. Required whenever ``sents`` is non-empty.
        evaluate: mask the target's strategy token in ``lm_labels``.
        block_size: maximum length; longer inputs lose whole leading
            utterances until they fit.
        str_embd: fill ``strategy_ids`` with the last strategy label instead of 8.

    Returns:
        InputFeatures_train. Every field is an empty list when ``sents`` is empty.

    Raises:
        ValueError: if the feature ends up empty, or if an over-long input has
            no ``eos`` within its last ``block_size`` tokens.
    """
    if len(sents) == 0:
        return InputFeatures_train([], [], [], [], [],
                            [], [] , [], [], [])

    input_ids = [tok for s in sents for tok in s + [eos]]
    lm_labels, token_type_ids, roles, strategy_ids, emotion_ids = [], [], [], [], []
    last_utt = len(sents) - 1
    for k, s in enumerate(sents):
        span = len(s) + 1  # utterance tokens plus the trailing eos
        token_type_ids += [ts[k]] * span
        # str_embd uses strategy as an embedding; by default strategy is treated as a token.
        strategy_ids += [strategy_labels[-1] if str_embd else 8] * span
        roles += [rls[k]] * span
        emotion_ids += [utt_emo_label[k]] * span
        # Only the final utterance is the autoregressive target.
        lm_labels += (s + [eos]) if k == last_utt else [-100] * span

    # Drop trailing positions without a target. The final label is always eos,
    # so this is a no-op unless eos == -100.
    end = len(lm_labels)
    while end > 0 and lm_labels[end - 1] == -100:
        end -= 1
    input_ids = input_ids[:end]
    lm_labels = lm_labels[:end]
    token_type_ids = token_type_ids[:end]
    roles = roles[:end]
    emotion_ids = emotion_ids[:end]
    strategy_ids = strategy_ids[:end] if str_embd else [8] * len(input_ids)

    assert (len(input_ids) == len(token_type_ids)
            == len(lm_labels) == len(roles) == len(strategy_ids) == len(emotion_ids))

    # Cut right after the first eos inside the last ``block_size`` tokens so only
    # whole leading utterances are dropped. The window used to be hard-coded to 512.
    if len(input_ids) > block_size:
        cut_index = input_ids.index(eos, -block_size) + 1
        input_ids = input_ids[cut_index:]
        token_type_ids = token_type_ids[cut_index:]
        lm_labels = lm_labels[cut_index:]
        roles = roles[cut_index:]
        strategy_ids = strategy_ids[cut_index:]
        emotion_ids = emotion_ids[cut_index:]

    # pad to multiples of 8
    if pad:
        while len(input_ids) % 8 != 0:
            input_ids.append(0)
            token_type_ids.append(0)
            lm_labels.append(-100)
            roles.append(-1)  # 2 roles; -1 marks padding
            emotion_ids.append(6)  # 6 emotion categories; 6 marks padding
            strategy_ids.append(8)  # 8 strategy categories; 8 marks padding/none
        assert len(input_ids) % 8 == 0

    position_ids = list(range(len(input_ids)))
    assert (len(input_ids) == len(position_ids) == len(token_type_ids) == len(lm_labels) == len(roles) == len(strategy_ids) == len(emotion_ids))

    if not input_ids:
        # Previously dropped into pdb here; fail loudly instead.
        raise ValueError("empty feature for conversation {}".format(id_))
    if len(input_ids) == 1:
        logger.debug("single-token feature: %s %s %s %s %s",
                     input_ids, lm_labels, token_type_ids, roles, emotion_ids)

    # Generation setting: the read-out position is the last eos of the context.
    # (The alternative, the second-to-last eos, was dead code and has been removed.)
    cls_position = len(input_ids) - 1 - input_ids[::-1].index(eos)

    if evaluate and strategy_labels[-1] != 8:
        # Do not score the target's strategy token.
        # NOTE(review): the id offset 50257 + 4687 is taken as-is — confirm it
        # matches the tokenizer in use.
        strategy_token = strategy_labels[-1] + 50257 + 4687
        if strategy_token in lm_labels:
            lm_labels[lm_labels.index(strategy_token)] = -100

    feature = InputFeatures_train(id_, input_ids, position_ids, token_type_ids, roles,
                            lm_labels, cls_position , strategy_labels[-1], strategy_ids, emotion_ids)
    return feature


def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    if not args.save_total_limit:
        return
    if args.save_total_limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    if len(checkpoints_sorted) <= args.save_total_limit:
        return

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)        

def construct_conv_ESC(args, idx, row, utt_emotion, tokenizer, evaluate=False, strategy=True):
    """Build one ``InputFeatures_blender`` from an ``EOS``-delimited conversation row.

    Everything before the last ``EOS`` is the encoder context (``[CLS]``-prefixed,
    padded to a multiple of 8). The final segment is the decoder target
    (unpadded).
    """
    segments = row.split("EOS")
    context_text = "EOS".join(segments[:-1])
    response_text = segments[-1]

    enc_inputs, enc_roles, enc_turns, enc_strategies, enc_emotions, _ = _get_input_from_text(
        context_text, tokenizer, utt_emotion, add_special=True, strategy=strategy)

    # NOTE(review): the response is segment 0 of its own split, so it is paired
    # with utt_emotion[0] rather than the last utterance's emotion — confirm intended.
    dec_inputs, dec_roles, dec_turns, dec_strategies, dec_emotions, _ = _get_input_from_text(
        response_text, tokenizer, utt_emotion, add_special=False, strategy=strategy)

    cls_id = tokenizer.encode(tokenizer.cls_token)[0]

    encoder_feature = _make_feature(
        args, idx, enc_inputs, enc_emotions, enc_roles, enc_turns, cls_id,
        tokenizer.bos_token_id, tokenizer.eos_token_id, pad=True,
        strategy_labels=enc_strategies, evaluate=evaluate, decoder=False)

    decoder_feature = _make_feature(
        args, idx, dec_inputs, dec_emotions, dec_roles, dec_turns, cls_id,
        tokenizer.bos_token_id, tokenizer.eos_token_id,
        strategy_labels=dec_strategies, evaluate=evaluate, decoder=True)

    return InputFeatures_blender(encoder_feature, decoder_feature)

def process_row_to_comet_query(row):
    """Return the text of the last seeker utterance (role 0) in an ``EOS``-delimited row.

    Segments look like ``"<emotion> <role> <turn> <tokens...>"``. Segments with
    fewer than two fields are skipped; for example, the empty tail after a
    trailing ``EOS`` used to raise IndexError.

    Returns:
        The fields after the first three, re-joined with ``' '``, or None if the
        row has no seeker utterance.

    Raises:
        ValueError: if a segment's role field is not an integer.
    """
    for segment in reversed(row.strip().split('EOS')):
        tokens = segment.strip().split(' ')
        if len(tokens) < 2:
            continue
        if int(tokens[1]) == 0:
            return ' '.join(tokens[3:])
    return None
            
def summary(test_file_path, generate_file_path, reference_file_path, summary_file_path, chat_texts, test_situation_file_path):
    """Write one tab-separated line per test example: context utterances, then the generated reply.

    Supporter utterances have their leading ``"[Strategy] "`` tag stripped.

    Args:
        test_file_path: text file with one ``" EOS"``-delimited conversation per
            line. The last split element is dropped, presumably the empty string
            after a trailing newline.
        generate_file_path: JSON list of generated replies (strings).
        reference_file_path: JSON list of reference replies (only zipped over; not written).
        summary_file_path: output path; overwritten.
        chat_texts: raw conversation rows used to build the COMET query (not
            written in the current output format).
        test_situation_file_path: text file of situations, one per line (only zipped over).
    """
    with open(test_file_path, "r", encoding="utf-8") as f:
        ctx = f.read().split("\n")
    with open(test_situation_file_path, "r", encoding="utf-8") as f:
        st = f.read().split("\n")
    ctx = ctx[:-1]
    st = st[:-1]
    with open(generate_file_path, "r", encoding="utf-8") as f:
        gen_rep = json.load(f)
    with open(reference_file_path, "r", encoding="utf-8") as f:
        ref_rep = json.load(f)
    with open(summary_file_path, 'w', encoding='utf-8') as f:
        for (ctx_row, ref_rep_row, gen_rep_row, chat_text, st_row) in zip(ctx, ref_rep, gen_rep, chat_texts, st):
            query = process_row_to_comet_query(chat_text)
            if query is None:
                query = ""

            context = ctx_row.split(' EOS')
            utterances = []
            # The last segment is skipped (the response, as in construct_conv_ESC).
            for item in context[:-1]:
                # Plain text: no [CLS] prefix. Passed explicitly because the old
                # one-argument call raised TypeError.
                _, src_role, _, src = _norm_text(item, False)
                if src_role == 0:
                    utt = src
                else:
                    # Drop the leading "[Strategy] " tag. An untagged utterance is
                    # kept whole instead of raising IndexError.
                    utt = src.split("] ", 1)[-1]
                utterances.append(utt)

            line = '\t'.join(utterances) + '\t' + gen_rep_row + '\n'
            # line = '[contxt]\t' + ctx_row + '\n[reference_response]\t' + ref_rep_row + '\n[hypothesis_response]\t' + gen_rep_row + '\n[comet query]\t' + query +'\n[situation]\t' + st_row + '\n[situation comet blocks (attention top5)]\t' + '\n' * 2

            f.write(line)
   
class ESCDataset(Dataset):
    """Dataset of ``InputFeatures_blender`` examples built from ESC conversation rows.

    Features are built once with ``construct_conv_ESC`` and pickled under
    ``args.data_cache_dir``. Later runs load the pickle unless
    ``args.overwrite_cache`` is set.
    """

    def __init__(self,  tokenizer: PreTrainedTokenizer, args, df, utter_emo, block_size=512, evaluate=False, strategy=True, test=False):
        # Subtract the room the tokenizer reserves for special tokens.
        # NOTE(review): block_size only affects the cache file name here; the
        # features themselves use _make_feature's default block size.
        block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)
        self.tokenizer = tokenizer

        directory = args.data_cache_dir
        if not os.path.exists(directory):
            os.makedirs(directory)

        # NOTE(review): the cache key ignores ``strategy``; delete the cache (or
        # set overwrite_cache) after changing it.
        if not evaluate:
            cache_name = "trn_cached_lm_"
        elif test:
            cache_name = "test_cached_lm_"
        else:
            cache_name = "val_cached_lm_"
        cached_features_file = os.path.join(directory, cache_name + str(block_size))

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            # pickle.load can execute arbitrary code: only use trusted cache dirs.
            with open(cached_features_file, 'rb') as handle:
                self.features = pickle.load(handle)
            return

        logger.info("Creating features from dataset file at %s", directory)
        # The last row of ``df`` is skipped (``df[:-1]``); presumably a trailing
        # empty line — confirm against the loader.
        self.features = [
            construct_conv_ESC(args, idx, row, emo['emotions'], tokenizer,
                               strategy=strategy, evaluate=evaluate)
            for idx, (row, emo) in enumerate(zip(df[:-1], utter_emo))
        ]

        logger.info("Saving features into cached file %s", cached_features_file)
        with open(cached_features_file, "wb") as handle:
            pickle.dump(self.features, handle, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info("Saving finished~")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, item):
        return self.features[item]

    @staticmethod
    def collate(features):
        """Batch features into a 17-tuple of ``torch.long`` tensors.

        Sequences are right-padded with these values:
        - ids, positions and token types: 0
        - encoder roles: -1
        - decoder roles: 0
        - labels: -100
        - strategy: 8
        - emotion: 6

        ``decoder_emotion_ids`` is not collated.
        """
        def padded(attr, fill):
            return pad_sequence([torch.tensor(getattr(f, attr), dtype=torch.long) for f in features],
                                batch_first=True, padding_value=fill)

        def stacked(attr):
            return torch.tensor([getattr(f, attr) for f in features], dtype=torch.long)

        return (
            padded("input_ids", 0),
            padded("position_ids", 0),
            padded("token_type_ids", 0),
            padded("role_ids", -1),
            padded("lm_labels", -100),
            stacked("cls_position"),
            stacked("cls_label"),
            padded("strategy_ids", 8),
            padded("emotion_ids", 6),
            padded("decoder_input_ids", 0),
            padded("decoder_position_ids", 0),
            padded("decoder_token_type_ids", 0),
            padded("decoder_role_ids", 0),
            padded("decoder_lm_labels", -100),
            stacked("decoder_cls_position"),
            stacked("decoder_cls_label"),
            padded("decoder_strategy_ids", 8),
        )
        
def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False):
    
    ordering_and_checkpoint_path = []
    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted

def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``.

    All CUDA devices are also seeded when ``args.n_gpu > 0``.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
        
def load_and_cache_examples(args, tokenizer, df, emo, evaluate=False, strategy=True, test=False):
    """Thin factory returning an ``ESCDataset`` for the given split (default block size)."""
    dataset = ESCDataset(
        tokenizer=tokenizer,
        args=args,
        df=df,
        utter_emo=emo,
        evaluate=evaluate,
        strategy=strategy,
        test=test,
    )
    return dataset

# Exponential-growth KL-weight schedule (disabled).
# def get_kl_weight(steps, warmup_steps=706, max_value=100.0, growth_rate=5.0):
#     # 使用指数增长策略
#     return min((torch.exp(torch.tensor((steps + 1) / warmup_steps * growth_rate)) - 1) / (torch.exp(torch.tensor(growth_rate)) - 1) * max_value, max_value)

# Linear KL-weight schedule: base_value + (steps + 1) / warmup_steps (no upper bound).
def get_kl_weight(steps, warmup_steps= 706, base_value=1e4):  # warmup_steps ~ num_examples // batch_size; 1e4 worked best
    """Return the KL-loss weight for ``steps``: ``base_value + (steps + 1) / warmup_steps``.

    The weight grows linearly and without an upper bound. It starts at
    ``base_value + 1 / warmup_steps`` for ``steps == 0``.
    """
    # Earlier capped variant: min((steps + 1) / warmup_steps, max_value)
    progress = (steps + 1) / warmup_steps
    return base_value + progress
    
#Training  of the model
def train(args, train_dataset,model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    
    train_dataloader  = DataLoader(train_dataset, batch_size=args.train_batch_size, collate_fn=ESCDataset.collate, shuffle=False, drop_last=False)
    
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        
    model = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))
    
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    other = ["trans_interact", "encoder_attn_emotion", "fusion"]
    no_main = no_decay + other
    
    params = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params':[p for n, p in params if not any(nd in n for nd in no_main)], 'weight_decay': args.weight_decay, 'lr':2e-5},
        {'params':[p for n, p in params if not any(nd in n for nd in other) and any(nd in n for nd in no_decay) ], 'weight_decay': 0.0, 'lr': 2e-5},
        {'params':[p for n, p in params if any(nd in n for nd in other) and any(nd in n for nd in no_decay) ], 'weight_decay': 0.0, 'lr': 5e-5},
        {'params':[p for n, p in params if any(nd in n for nd in other) and not any(nd in n for nd in no_decay) ], 'weight_decay': args.weight_decay,'lr': 5e-5},
    ]
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')    
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    
    # Check if saved optimizer or scheduler states exist
    if False and (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        
    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)  
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        ).to(args.device)
        
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    
    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
     
    tr_loss, logging_loss, tr_ppl, logging_lm_loss, tr_emo_loss, \
    logging_emo_loss, tr_strategy_loss, logging_strategy_loss, tr_intensity_loss, logging_intensity_loss = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    best_ppl = 1e8
    
    tr_loss1, logging_loss1  = 0.0, 0.0
    
    model.zero_grad()
   
    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=True)
    set_seed(args)  # Added here for reproducibility
    
    import numpy as np
    np.set_printoptions(threshold=np.inf)
    
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Training Iteration", disable=False)
        model.train()
        
        for step, batch in enumerate(epoch_iterator):
            
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            
            input_ids, position_ids, token_type_ids, role_ids, labels, cls_positions, cls_labels, strategy_ids, emotion_ids, decoder_input_ids, decoder_position_ids, decoder_token_type_ids, decoder_role_ids, decoder_labels, decoder_cls_positions, decoder_cls_labels, decoder_strategy_ids = batch

            decoder_strategy_ids = decoder_strategy_ids[:, 0]
            decoder_strategy_ids = decoder_strategy_ids.to(args.device)
            
            model.train()
            
            if input_ids.shape[1] > 512: continue
            
            input_ids = input_ids.to(args.device)
            # turn_ids = turn_ids.to(args.device)
            role_ids = role_ids.to(args.device)
            decoder_input_ids = decoder_input_ids.to(args.device)
            # decoder_turn_ids = decoder_turn_ids.to(args.device)
            decoder_label_ids = decoder_labels.to(args.device)
            decoder_role_ids = decoder_role_ids.to(args.device)
            
            strategy_ids = strategy_ids.to(args.device)
            emotion_ids = emotion_ids.to(args.device)
            
            if not args.role:
                role_ids = None
                decoder_role_ids = None
                
            if not args.turn:
                turn_ids = None
                
            if not args.strategy:
                outputs = model(input_ids, attention_mask = input_ids.ne(tokenizer.pad_token_id), strategy_ids=strategy_ids, emotion_ids=emotion_ids, decoder_input_ids=decoder_input_ids, decoder_role_ids=decoder_role_ids, role_ids=role_ids, labels = decoder_label_ids)
                ppl = outputs[0]
                eps_all_kl_loss = outputs[1]
                po_rec_all_loss = outputs[2]
                causal_loss = outputs[3]
                
                kl_weight = get_kl_weight(step)
                loss = args.ppl_weight * ppl + args.causal_weight * causal_loss + args.rec_weight * po_rec_all_loss + kl_weight * eps_all_kl_loss
                
                # print("causal_loss", causal_loss)
                # print("po_rec_all_loss", po_rec_all_loss)
                # print("eps_all_kl_loss", kl_weight*eps_all_kl_loss)
                
            else:
                outputs = model(input_ids, attention_mask = input_ids.ne(tokenizer.pad_token_id), strategy_ids=strategy_ids, emotion_ids=emotion_ids, decoder_input_ids=decoder_input_ids, decoder_role_ids=decoder_role_ids, role_ids=role_ids, labels = decoder_label_ids)
                ppl = outputs.loss 
                eps_all_kl_loss = outputs[1]
                po_rec_all_loss = outputs[2]
                loss = ppl + po_rec_all_loss + args.beta_kl * eps_all_kl_loss  
                
            if not args.no_cuda and args.n_gpu >= 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                ppl = ppl.mean()
            
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()   
                
            tr_loss += loss.item()
            tr_ppl += ppl.item()
            
            if args.strategy:
                tr_loss1 += outputs[1].mean().item()
                
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                
                global_step += 1
                
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0 and global_step > t_total*0.0:
                    
                    model.eval()
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        
                        eval_results = evaluate(args, model, tokenizer, args.eval_dataset, "{}-{}".format("checkpoint", global_step), eval_or_test="eval")
                        logger.info("eval_results in Training %s", eval_results)
                        
                        test_results = evaluate(args, model, tokenizer, args.test_dataset, "{}-{}".format("checkpoint", global_step), eval_or_test="test")
                        logger.info("test_results in Training %s", test_results)
                        
                    model.train()
                    logger.info("lr: %f, step: %d, loss: %f, ppl: %f ", scheduler.get_lr()[0], global_step, (tr_loss - logging_loss) / args.logging_steps, (tr_loss1- logging_loss1) / args.logging_steps)
                    logging_loss = tr_loss
                    logging_ppl = tr_ppl
                    logging_loss1 = tr_loss1
                    
                    if eval_results['eval_perplexity'] < best_ppl:
                        best_ppl = eval_results['eval_perplexity']

                        if args.save:
                            checkpoint_prefix = "checkpoint"

                            output_dir = args.output_dir
                            os.makedirs(output_dir, exist_ok=True) #如果不存在就创建他
                            
                            model_to_save = (
                                model.module if hasattr(model, "module") else model
                            )  # Take care of distributed/parallel training
                            
                            model_to_save.save_pretrained(output_dir)
                            tokenizer.save_pretrained(output_dir)

                            torch.save(args, os.path.join(output_dir, "training_args.bin"))
                            logger.info("Saving model checkpoint to %s", output_dir)

                            _rotate_checkpoints(args, checkpoint_prefix)
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            logger.info("Saving optimizer and scheduler states to %s", output_dir)
            
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break
        
    print("Train finished~")
    return global_step, tr_loss / global_step

def evaluate(args, model, tokenizer, eval_dataset,  prefix="", eval_or_test="eval"):
    """Compute the token-weighted perplexity of ``model`` on ``eval_dataset``.

    Batches are collated with ``ESCDataset.collate`` and unpacked into the same
    17 fields ``train`` uses; the model's forward is called with the same
    keyword arguments as in ``train`` so evaluation matches training. Each
    batch's loss ``outputs[0]`` is multiplied by the batch's number of target
    tokens (labels != -100), i.e. it is treated as a per-token mean
    (assumption — confirm against the model's loss reduction). Perplexity is
    ``exp(sum of weighted losses / total target tokens)``. The result is logged
    and appended to ``<args.output_dir>/eval_results.txt``.

    Args:
        args: parsed command-line namespace; ``args.eval_batch_size`` is set
            here as a side effect.
        model: the seq2seq model to score.
        tokenizer: tokenizer providing ``pad_token_id`` for the attention mask.
        eval_dataset: dataset to score.
        prefix: label used in log lines and in the results file.
        eval_or_test: currently unused; kept for call compatibility.

    Returns:
        ``{"eval_perplexity": float}``; ``inf`` when no target token was scored
        (e.g. every batch was longer than 1024 tokens).
    """
    eval_output_dir = args.output_dir
    os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=SequentialSampler(eval_dataset),
        batch_size=args.eval_batch_size,
        collate_fn=ESCDataset.collate,
        drop_last=False,
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model = amp.initialize(model, opt_level=args.fp16_opt_level)

    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    model.eval()

    weighted_loss_sum = 0.0
    total_target_tokens = 0
    # Only filled when args.strategy is set; collected but not returned yet.
    strategy_probs = []
    cls_labels_list = []
    # Turn ids are not part of the collated batch, so there is nothing to feed
    # (previously this name was undefined when --turn was enabled).
    turn_ids = None

    for batch in tqdm(eval_dataloader, desc="Evaluating", disable=False):
        (input_ids, position_ids, token_type_ids, role_ids, labels, cls_positions, cls_labels,
         strategy_ids, emotion_ids, decoder_input_ids, decoder_position_ids, decoder_token_type_ids,
         decoder_role_ids, decoder_labels, decoder_cls_positions, decoder_cls_labels,
         decoder_strategy_ids) = batch

        if input_ids.shape[1] > 1024:
            continue

        input_ids = input_ids.to(args.device)
        decoder_input_ids = decoder_input_ids.to(args.device)
        decoder_label_ids = decoder_labels.to(args.device)
        decoder_cls_labels = decoder_cls_labels.to(args.device)
        strategy_ids = strategy_ids.to(args.device)
        emotion_ids = emotion_ids.to(args.device)

        # Same rule as train(): role embeddings are dropped only when --role is off.
        if args.role:
            role_ids = role_ids.to(args.device)
            decoder_role_ids = decoder_role_ids.to(args.device)
        else:
            role_ids = None
            decoder_role_ids = None

        # Mask padding in the encoder input, exactly as train() does.
        attention_mask = input_ids.ne(tokenizer.pad_token_id)

        with torch.no_grad():
            if args.strategy:
                outputs = model(input_ids, attention_mask=attention_mask, strategy_ids=strategy_ids,
                                emotion_ids=emotion_ids, decoder_input_ids=decoder_input_ids,
                                decoder_role_ids=decoder_role_ids, role_ids=role_ids,
                                labels=decoder_label_ids)
                cls_labels_list.extend(decoder_cls_labels.cpu().numpy().tolist())
                # NOTE(review): 54945..54952 is a hard-coded vocab slice of 8 entries — verify it
                # still matches the strategy tokens after tokenizer changes.
                strategy_probs.append(torch.nn.functional.softmax(
                    outputs.logits[0, 0, 54945:54945 + 8], dim=-1).cpu().numpy().tolist())
            else:
                outputs = model(input_ids, attention_mask=attention_mask, strategy_ids=strategy_ids,
                                emotion_ids=emotion_ids, decoder_input_ids=decoder_input_ids,
                                decoder_role_ids=decoder_role_ids, turn_ids=turn_ids,
                                role_ids=role_ids, labels=decoder_label_ids)

            # Count target tokens on the tensor directly (np.int was removed in NumPy 1.24).
            n_target_tokens = int((decoder_label_ids != -100).sum().item())
            weighted_loss_sum += outputs[0].sum().item() * n_target_tokens
            total_target_tokens += n_target_tokens

    if total_target_tokens == 0:
        logger.warning("No target tokens were evaluated %s; reporting infinite perplexity.", prefix)
        perplexity = float("inf")
    else:
        eval_loss = weighted_loss_sum / total_target_tokens
        # float64 keeps the same precision the previous numpy-based accumulation had.
        perplexity = torch.exp(torch.tensor(eval_loss, dtype=torch.float64)).item()

    result = {"eval_perplexity": perplexity}
    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")

    with open(output_eval_file, "a+") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        writer.write("***** Eval results {} *****".format(prefix) + "\n")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
                  
def main(args):
    """Entry point: train the model (default) or, without ``do_train``, generate.

    With ``args.do_train`` set, this does the following:

    - loads the tokenizer and model and adds a ``[CLS]`` token;
    - builds the train/dev/test datasets from ``args.data_path``;
    - runs ``train``;
    - reports test-set perplexity.

    The perplexity comes from the checkpoint saved in ``args.output_dir``
    when one exists, otherwise from the in-memory model. Without
    ``do_train``, it delegates to ``generate(args)``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
            ``device``, ``n_gpu``, ``eval_dataset``, ``test_dataset`` and
            ``train_dataset`` are set on it as side effects.

    Raises:
        ValueError: if ``--should_continue`` finds no checkpoint, or if the
            output directory is non-empty and may not be overwritten.
    """
    # Debugging aid: autograd reports the forward op behind NaN/inf gradients.
    # The PyTorch docs note it slows execution; switch it off for speed.
    torch.autograd.set_detect_anomaly(True)

    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        args.model_name_or_path = sorted_checkpoints[-1]

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
        and not args.should_continue
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Device setup: a single GPU (cuda:0 when available) or the CPU.
    if not args.no_cuda:
        args.n_gpu = 1
        args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        args.n_gpu = 0
        args.device = torch.device("cpu")

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        args.device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    set_seed(args)

    # A [CLS] special token is added, so the embedding matrix is resized to match.
    tokenizer = BlenderbotSmallTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.model_cache_dir)
    tokenizer.add_special_tokens({'cls_token': '[CLS]'})
    model = BlenderbotSmallForConditionalGeneration.from_pretrained(args.model_name_or_path, cache_dir=args.model_cache_dir)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    def read_split(text_name, emotion_name):
        """Return (lines of the dialogue file, parsed emotion JSON) for one split."""
        with open(os.path.join(args.data_path, text_name), "r", encoding="utf-8") as f:
            lines = f.read().split("\n")
        with open(os.path.join(args.data_path, emotion_name), "r", encoding="utf-8") as f:
            emotions = json.load(f)
        return lines, emotions

    df_trn, emo_trn = read_split(args.train_file_name, args.train_emotion_name)
    df_val, emo_val = read_split(args.eval_file_name, args.eval_emotion_name)
    df_test, emo_test = read_split(args.test_file_name, args.test_emotion_name)

    args.eval_dataset = load_and_cache_examples(args, tokenizer, df_val, emo_val, evaluate=True, strategy=args.strategy, test=False)
    args.test_dataset = load_and_cache_examples(args, tokenizer, df_test, emo_test, evaluate=True, strategy=args.strategy, test=True)

    if not args.do_train:
        generate(args)
        return

    os.makedirs(args.output_dir, exist_ok=True)
    logger.info("Saving model checkpoint to %s", args.output_dir)

    args.train_dataset = load_and_cache_examples(args, tokenizer, df_trn, emo_trn, evaluate=False, strategy=args.strategy)

    global_step, tr_loss = train(args, args.train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # train() saves the model (when args.save is set) whenever eval perplexity
    # improves. Reload that checkpoint if one was written; otherwise evaluate
    # the model that was just trained instead of crashing in from_pretrained.
    if os.path.isfile(os.path.join(args.output_dir, "config.json")):
        model = BlenderbotSmallForConditionalGeneration.from_pretrained(args.output_dir, from_tf=False)
        model.to(args.device)
    else:
        logger.warning("No saved checkpoint found in %s; evaluating the in-memory model.", args.output_dir)

    test_results = evaluate(args, model, tokenizer, args.test_dataset, "of test set")
    logger.info("Test dataset perplexity: %f", test_results['eval_perplexity'])
        
def generate_backup(args):
    """Older generation + strategy-scoring routine, kept for reference.

    In addition to sampling one response per test dialogue (like ``generate``),
    it scores strategy predictions. It unpacks a second return value
    (``strategy_logits``) from ``model.generate`` and prints top-k strategy hit
    rates for k = 1..8. It then writes hypotheses, references and per-dialogue
    strategy logits to ``args.generation_dir`` and runs ``summary`` and
    ``Metric`` on them.

    NOTE(review): this function is not called from the visible part of the
    file. It assumes ``model.generate`` returns a ``(ids, strategy_logits)``
    pair and accepts ``cls_token_id``. The active ``generate`` unpacks a
    single return value instead, so confirm the model still supports this
    before reviving it.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block);
            ``device``, ``n_gpu`` and ``test_dataset`` are set on it.
    """
    
    tokenizer = BlenderbotSmallTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.model_cache_dir)
    tokenizer.add_special_tokens({'cls_token': '[CLS]'})
    
    model = BlenderbotSmallForConditionalGeneration.from_pretrained(args.output_dir, from_tf=False)
    model.resize_token_embeddings(len(tokenizer))
    
    # Device setup: a single GPU (cuda:0 when available) or the CPU.
    if not args.no_cuda:
        # Single GPU
        n_gpu = 1
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        args.n_gpu, args.device = n_gpu, device
    else:
        device = torch.device("cpu")
        args.device = device
        args.n_gpu = 0
        
    set_seed(args)
    
    # NOTE(review): no explicit encoding here, unlike the emotion file below;
    # the platform default encoding is used.
    with open(args.data_path+"/"+args.test_file_name,"r") as f:
        chat_texts = f.read().split("\n")
    
    with open(args.data_path+"/"+ args.test_emotion_name, "r", encoding="utf-8") as f:
        emo_test = json.load(f)
        
    model.eval()
    model.to(args.device)
    
    args.test_dataset = load_and_cache_examples(args, tokenizer, chat_texts, emo_test, evaluate=True, strategy=args.strategy, test=True)
    
    # Test-set perplexity, printed before generation starts.
    test_results = evaluate(args, model, tokenizer, args.test_dataset, "of test set")
    print('Test Results:', test_results)
    
    gts = []  # gold responses (written to ref_strategy.json)
    refs = []  # generated responses (written to hyp_strategy.json)
    strategy_logit_str = []  # tab-separated last-position strategy logits per dialogue
    
    # Strategy-scoring accumulators.
    strategy_hits = []
    strategy_record = []  # never appended to (the append below is commented out), so dumped empty
    strategy_hits_topk = [[] for _ in range(8)]  # index k holds top-(k+1) hits per dialogue
    
    # chat_texts[:-1] drops the final entry, presumably the empty string after a trailing newline.
    for idx, (c_text, emo) in tqdm(enumerate(zip(chat_texts[:-1], emo_test)), desc="Testing"):
        
        chat_history = c_text
        f = construct_conv_ESC(args, idx, chat_history, emo['emotions'], tokenizer, strategy=False)
        
        # Gold strategy id of the response to be generated.
        next_strategy_id = f.decoder_strategy_ids[0]
        decoder_strategy_ids = torch.tensor([f.decoder_strategy_ids], dtype=torch.long)
        decoder_strategy_ids = decoder_strategy_ids.to(device)
        decoder_strategy_ids = decoder_strategy_ids[:, 0]
        strategy_ids = torch.tensor([f.strategy_ids], dtype=torch.long).to(device)
        input_ids = torch.tensor([f.input_ids], dtype=torch.long).to(args.device)

        gts.append(tokenizer.decode(f.decoder_input_ids, skip_special_tokens=True))

        # Extra keyword arguments forwarded to model.generate (batch of one).
        paras = {}
        paras["attention_mask"] =  input_ids.ne(tokenizer.pad_token_id)
        paras['strategy_ids'] = strategy_ids
        paras["emotion_ids"] = torch.tensor([f.emotion_ids], dtype=torch.long).to(device)
        paras['decoder_input_ids'] = torch.tensor([f.decoder_input_ids], dtype=torch.long).to(device)
        paras["decoder_role_ids"] = torch.tensor([f.decoder_role_ids], dtype=torch.long).to(device)
        paras['role_ids'] = torch.tensor([f.role_ids], dtype=torch.long).to(device)
        # paras['labels'] = torch.tensor([f.decoder_lm_labels], dtype=torch.long).to(device)
        
        # Sampled decoding; this model's generate is expected to also return strategy logits.
        chat_history_ids, strategy_logits = model.generate(
            input_ids, 
            **paras, max_length=512, min_length=5, num_beams=1,
            pad_token_id=0, use_cache=True,
            eos_token_id=tokenizer.eos_token_id, cls_token_id=tokenizer.cls_token_id, temperature=0.7,
            top_p=0.4, top_k=20, do_sample=True, repetition_penalty=1.03) #top_p 0.9, topk 30
        
        chat_history_ids = chat_history_ids.cpu()
        
        refs.append(tokenizer.decode(chat_history_ids[:, :][0], skip_special_tokens=True))
    
        # id2strategy = {0: "[Question]", 1: "[Reflection of feelings]", 2: "[Information]", 3: "[Restatement or Paraphrasing]", 4: "[Others]", 5: "[Self-disclosure]", 6: "[Affirmation and Reassurance]", 7: "[Providing Suggestions]", 8: "[No Strategy]"}
        # strategy_record.append({"ref strategy": id2strategy[next_strategy_id],  "hyp strategy": id2strategy[strategy_logits[0][-1].argmax().item()]})
        
        # For each row, find the last position whose strategy label is not 8
        # ("[No Strategy]" in the id2strategy map above) and count a hit when
        # the strategy logits at that position pick the same label.
        # NOTE(review): strategy_logits is indexed as [row][position] here and
        # as [:, -1, :] below — assumes shape (batch, positions, 8); confirm.
        for batch, label in enumerate(strategy_ids): # strategy_labels
            j = len(label) - 1
            
            while j >= 0:
                if label[j] != 8:
                    break
                j -= 1
            # Sanity check: that last real strategy must be the gold next strategy.
            assert next_strategy_id == strategy_ids[batch][j] != 8
            if strategy_logits[batch][j].argmax() == strategy_ids[batch][j]:
                strategy_hits.append(1)
            else:
                strategy_hits.append(0)
                
        decoder_strategy_logits = strategy_logits[:, -1, :]
        
        # Top-k hit (1 if the gold strategy is among the k+1 highest logits) for k = 0..7.
        for k in range(8):
            _, topk = decoder_strategy_logits[0].topk(k+1, -1)
            strategy_hits_topk[k].append(sum((topk == next_strategy_id).cpu().numpy().tolist()))
            
        decoder_strategy_logits = decoder_strategy_logits[0].cpu().numpy().tolist()
        decoder_strategy_logits = ["%.4f" % logit for logit in decoder_strategy_logits]
        strategy_logit_str.append('\t'.join(decoder_strategy_logits))
        
    # Top-1 .. top-8 strategy accuracy (raises ZeroDivisionError if no dialogue was processed).
    for i in range(8):
        print(sum(strategy_hits_topk[i]) / len(strategy_hits_topk[i]))

    if not os.path.exists(args.generation_dir):
        os.makedirs(args.generation_dir)
        
    # NOTE(review): the two dataset paths below are hard-coded and ignore args.data_path.
    test_file_path = "dataset/testWithStrategy_short.tsv"
    test_situation_file_path = "dataset/testSituation.txt"
    strategy_record_file_path = os.path.join(args.generation_dir, "strategy_record.json")
    generate_file_path = os.path.join(args.generation_dir, "hyp_strategy.json")
    reference_file_path = os.path.join(args.generation_dir, "ref_strategy.json")
    summary_file_path = os.path.join(args.generation_dir, "summary.txt")
    strategy_logits_file = os.path.join(args.generation_dir, "strategy_logits.txt")
    
    with open(strategy_logits_file, "w", encoding="utf-8") as f:
        for item in strategy_logit_str:
            f.write(item + '\n')

    # Note the naming: `refs` (generated) goes to the hyp file, `gts` (gold) to the ref file.
    with open(strategy_record_file_path, "w",encoding="utf-8") as f:
        json.dump(strategy_record,f,indent=2,ensure_ascii=False)
    with open(generate_file_path, "w",encoding="utf-8") as f:
        json.dump(refs,f,indent=2,ensure_ascii=False)
    with open(reference_file_path,"w",encoding="utf-8") as f:
        json.dump(gts,f,indent=2,ensure_ascii=False)
        
    summary(test_file_path, generate_file_path, reference_file_path, summary_file_path, chat_texts, test_situation_file_path)

    print("write result to:", summary_file_path)
    print("Generate finished~")
    metric = Metric(toker=tokenizer, hyp_path=generate_file_path, ref_path=reference_file_path)
    result, result_list = metric.close()
    print(result)
    print("=" * 100)
    
    
def generate(args):
    """Sample one response per test dialogue and score the responses with ``Metric``.

    Steps:

    1. Load the tokenizer from ``args.model_name_or_path``, adding the
       ``[CLS]`` token as in training.
    2. Load the fine-tuned model from ``args.output_dir``.
    3. Build features for every line of the test file with
       ``construct_conv_ESC``.
    4. Sample a response with ``model.generate``.
    5. Write the generated responses to ``hyp_strategy.json`` and the gold
       responses to ``ref_strategy.json`` under ``args.generation_dir``.
    6. Print the metrics returned by ``Metric.close()``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block);
            ``device``, ``n_gpu`` and ``test_dataset`` are set on it.
    """
    tokenizer = BlenderbotSmallTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.model_cache_dir)
    tokenizer.add_special_tokens({'cls_token': '[CLS]'})

    model = BlenderbotSmallForConditionalGeneration.from_pretrained(args.output_dir, from_tf=False)
    model.resize_token_embeddings(len(tokenizer))

    # Device setup: a single GPU (cuda:0 when available) or the CPU.
    if not args.no_cuda:
        args.n_gpu = 1
        args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        args.n_gpu = 0
        args.device = torch.device("cpu")
    device = args.device

    set_seed(args)

    # Read the test split as UTF-8, like every other data file in this script.
    with open(os.path.join(args.data_path, args.test_file_name), "r", encoding="utf-8") as fh:
        chat_texts = fh.read().split("\n")
    with open(os.path.join(args.data_path, args.test_emotion_name), "r", encoding="utf-8") as fh:
        emo_test = json.load(fh)

    model.eval()
    model.to(device)

    # Only needed if the perplexity evaluation below is re-enabled.
    args.test_dataset = load_and_cache_examples(args, tokenizer, chat_texts, emo_test, evaluate=True, strategy=args.strategy, test=True)
    # test_results = evaluate(args, model, tokenizer, args.test_dataset, "of test set")
    # print('Test Results:', test_results)

    # split("\n") yields a trailing "" when the file ends with a newline; drop
    # only that empty entry so a file without one keeps its last dialogue.
    dialogues = chat_texts[:-1] if chat_texts and chat_texts[-1] == "" else chat_texts

    hypotheses = []  # generated responses -> hyp_strategy.json
    references = []  # gold responses -> ref_strategy.json

    for idx, (c_text, emo) in tqdm(enumerate(zip(dialogues, emo_test)), desc="Testing"):
        feat = construct_conv_ESC(args, idx, c_text, emo['emotions'], tokenizer, evaluate=True, strategy=True)
        # Gold target: the part of the dialogue the model is asked to generate.
        references.append(tokenizer.decode(feat.decoder_input_ids, skip_special_tokens=True))

        input_ids = torch.tensor([feat.input_ids], dtype=torch.long).to(device)
        paras = {
            "attention_mask": input_ids.ne(tokenizer.pad_token_id),
            "strategy_ids": torch.tensor([feat.strategy_ids], dtype=torch.long).to(device),
            "emotion_ids": torch.tensor([feat.emotion_ids], dtype=torch.long).to(device),
            "decoder_role_ids": torch.tensor([feat.decoder_role_ids], dtype=torch.long).to(device),
            "role_ids": torch.tensor([feat.role_ids], dtype=torch.long).to(device),
        }

        # Prints the input dialogue ("输入的对话" = "input dialogue").
        print("输入的对话：", c_text)
        # Sample a response conditioned on the input dialogue.
        chat_history_ids = model.generate(
            input_ids,
            **paras, max_length=512, min_length=5, num_beams=1,
            pad_token_id=0, use_cache=True,
            eos_token_id=tokenizer.eos_token_id, temperature=0.7,
            top_p=0.4, top_k=20, do_sample=True, repetition_penalty=1.03)  # top_p 0.9, topk 30

        response = tokenizer.decode(chat_history_ids.cpu()[0], skip_special_tokens=True)
        print("EmotionalSupportGPT: {}".format(response))
        hypotheses.append(response)

    os.makedirs(args.generation_dir, exist_ok=True)

    generate_file_path = os.path.join(args.generation_dir, "hyp_strategy.json")
    reference_file_path = os.path.join(args.generation_dir, "ref_strategy.json")

    with open(generate_file_path, "w", encoding="utf-8") as fh:
        json.dump(hypotheses, fh, indent=2, ensure_ascii=False)
    with open(reference_file_path, "w", encoding="utf-8") as fh:
        json.dump(references, fh, indent=2, ensure_ascii=False)

    print("Generate finished~")
    metric = Metric(toker=tokenizer, hyp_path=generate_file_path, ref_path=reference_file_path)
    result, _ = metric.close()
    print(result)
    print("=" * 100)
    
def infernece(args=None, max_history_turns=3):
    """Interactive console chat loop (unfinished).

    Loads the tokenizer and the fine-tuned model from ``args.output_dir``,
    asks for a starting turn index, then reads up to 20 user utterances from
    stdin. Each utterance is added to a dialogue-history string, using one
    ``"1.0 0 <turn> <text> EOS "`` segment per turn.

    NOTE(review): the model is never called yet. Building features from the
    history (``construct_conv_ESC``) and generating a reply (see
    ``generate``) still have to be wired in.

    Args:
        args: parsed command-line namespace. Defaults to the module-level
            ``args`` created in the ``__main__`` block. The original
            version read that global implicitly and raised ``NameError``
            when it was absent.
        max_history_turns: number of most recent user turns kept in the
            history string (at least one is always kept).

    Raises:
        ValueError: if no ``args`` was passed and no module-level ``args``
            exists.
    """
    if args is None:
        args = globals().get("args")
        if args is None:
            raise ValueError("infernece() needs the parsed command-line arguments")

    tokenizer = BlenderbotSmallTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.model_cache_dir)
    tokenizer.add_special_tokens({'cls_token': '[CLS]'})

    model = BlenderbotSmallForConditionalGeneration.from_pretrained(args.output_dir, from_tf=False)
    model.resize_token_embeddings(len(tokenizer))

    # Device setup: a single GPU (cuda:0 when available) or the CPU.
    if not args.no_cuda:
        args.n_gpu = 1
        args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        args.n_gpu = 0
        args.device = torch.device("cpu")

    set_seed(args)
    model.eval()
    model.to(args.device)

    keep = max(1, max_history_turns)
    history_turns = []
    chat_history = ''
    start_turn = int(input("which turn do you want to start?(0-26)"))
    for _ in range(20):
        usr_inp = input(">> User:")
        history_turns.append("1.0 0 " + str(start_turn) + " " + usr_inp + " EOS ")
        # Sliding window: keep only the most recent `keep` turns so the
        # history never shrinks to nothing and segments stay space-separated.
        history_turns = history_turns[-keep:]
        chat_history = "".join(history_turns)
        start_turn += 1
        # TODO: f = construct_conv_ESC(args, idx, chat_history, emotions, tokenizer, evaluate=True, strategy=True)
        #       then call model.generate on the resulting features.
        
    

        
def parse_bool_arg(value):
    """Convert a command-line string to a bool for ``argparse``.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``, so
    ``--flag False`` used to *enable* the flag. This converter accepts the
    usual spellings. Use it with ``nargs="?", const=True`` so that a bare
    ``--flag`` still means ``True``.

    Raises:
        ValueError: for an unrecognised spelling; argparse reports it as an
            invalid argument value.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError("expected a boolean value, got %r" % (value,))


if __name__ == "__main__":

    parser = ArgumentParser()

    parser.add_argument("--output_dir", type=str, default='./blender_strategy', help="Path of output dir")
    parser.add_argument("--generation_dir", type=str, default='./generated_data', help="Path of output dir")
    parser.add_argument("--model_type", type=str, default='mymodel')
    parser.add_argument("--use_emotion", action='store_true', default=False)
    parser.add_argument("--use_bow", action='store_true', default=False)
    parser.add_argument("--model_name_or_path", type=str, default="./blenderbot_small-90M")
    parser.add_argument("--config_name", type=str, default="./blenderbot_small-90M")
    parser.add_argument("--tokenizer_name", type=str, default="./blenderbot_small-90M")
    parser.add_argument("--data_path", type=str, default="./dataset")
    parser.add_argument("--train_file_name", type=str, default="trainWithStrategy_short.tsv")
    parser.add_argument("--eval_file_name", type=str, default="devWithStrategy_short.tsv")
    parser.add_argument("--test_file_name", type=str, default="testWithStrategy_short.tsv")
    parser.add_argument("--train_emotion_name", type=str, default="train_emotion.json")
    parser.add_argument("--eval_emotion_name", type=str, default="dev_emotion.json")
    parser.add_argument("--test_emotion_name", type=str, default="test_emotion.json")

    parser.add_argument("--model_cache_dir", type=str, default="./blender-small")
    parser.add_argument("--data_cache_dir", type=str, default="./cached")
    parser.add_argument("--block_size", type=int, default=512)

    # Boolean switches: `--flag` alone means True; `--flag false` now really disables it.
    parser.add_argument("--do_train", type=parse_bool_arg, nargs="?", const=True, default=True)
    parser.add_argument("--do_eval", type=parse_bool_arg, nargs="?", const=True, default=True)
    parser.add_argument("--test", action='store_true', default=False)

    parser.add_argument("--generation", type=parse_bool_arg, nargs="?", const=True, default=True)

    # Was `store_true` with default=True, which could never be turned off.
    parser.add_argument("--save", type=parse_bool_arg, nargs="?", const=True, default=True)
    parser.add_argument("--evaluate_during_training", type=parse_bool_arg, nargs="?", const=True, default=True)

    parser.add_argument("--per_gpu_train_batch_size", type=int, default=20)
    parser.add_argument("--per_gpu_eval_batch_size", type=int, default=50)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)

    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--num_train_epochs", type=int, default=4)
    parser.add_argument("--max_steps", type=int, default=-1)
    parser.add_argument("--warmup_steps", type=int, default=120)  # 100

    # Loss weights are real numbers (defaults 1e2 / 0.1); type=int rejected e.g. "0.1".
    parser.add_argument("--rec_weight", type=float, default=1e2)
    parser.add_argument("--ppl_weight", type=float, default=2)
    parser.add_argument("--causal_weight", type=float, default=1e-1)
    # KL weight read by train() when --strategy is enabled; it was never defined.
    # 1.0 is a neutral placeholder — tune as needed.
    parser.add_argument("--beta_kl", type=float, default=1.0)

    parser.add_argument("--logging_steps", type=int, default=200)
    parser.add_argument("--save_steps", type=int, default=30)
    parser.add_argument("--save_total_limit", type=int, default=None)
    parser.add_argument("--eval_all_checkpoints", type=parse_bool_arg, nargs="?", const=True, default=False)

    parser.add_argument("--no_cuda", type=parse_bool_arg, nargs="?", const=True, default=False)

    parser.add_argument("--overwrite_output_dir", type=parse_bool_arg, nargs="?", const=True, default=True)
    parser.add_argument("--overwrite_cache", type=parse_bool_arg, nargs="?", const=True, default=False)
    parser.add_argument("--should_continue", type=parse_bool_arg, nargs="?", const=True, default=False)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--fp16", type=parse_bool_arg, nargs="?", const=True, default=False)
    parser.add_argument("--fp16_opt_level", type=str, default='O1')
    parser.add_argument("--strategy", type=parse_bool_arg, nargs="?", const=True, default=False)
    parser.add_argument("--turn", type=parse_bool_arg, nargs="?", const=True, default=False)
    parser.add_argument("--role", type=parse_bool_arg, nargs="?", const=True, default=True)

    args = parser.parse_args()

    if args.test:
        generate(args)
    else:
        main(args)
